from interface import Interface
from tools import validate_ip_with_subnet, check_aggregate_route, check_route_proto_type, check_secure_policy, \
    cidr_to_netmask
from firewall import Firewall
from router import Router
import jsonpickle
from netmiko import ConnectHandler
import time


def get_all_device_info():
    """Placeholder for a scheduled job that collects every device's information.

    The intent is to gather route tables, interface addresses and similar data, and store it on disk
    for later use. Not implemented yet: the function does nothing and returns None.
    """
    return None


def check_route(origin_ip, destination_ip):
    """Check whether a route exists from ``origin_ip`` to ``destination_ip``.

    The directly connected device (and the facing interface) is looked up for both endpoints.
    When both exist, the hop-by-hop walk in simulate_routing() decides the answer, and it records
    every hop in the module-level ``route_detail_list``.

    :param origin_ip: source address/netmask string.
    :param destination_ip: destination address/netmask string.
    :return: simulate_routing()'s result, or False when either endpoint has no directly connected
        device. In that case a placeholder record for the destination-side device (which may be
        None) is appended to ``route_detail_list``.
    """
    start_device, start_interface = find_direct_route_device_and_interface(origin_ip)
    end_device, end_interface = find_direct_route_device_and_interface(destination_ip)

    if start_device is None or end_device is None:
        route_detail_list.append({"device": end_device, "interface_in": "", "interface_out": ""})
        return False

    # Walk from the source's device towards the destination's device.
    return simulate_routing(start_device, end_device, destination_ip, start_interface, end_interface)


def find_direct_route_device_and_interface(ip):
    """Find the device directly connected to ``ip`` and the interface facing it.

    Devices in ``net_device_list`` are scanned in order. A route qualifies when both hold:
      1. ``check_aggregate_route`` says ``ip`` lies in the route's destination/netmask network;
      2. the route's ``proto_type`` maps to ``"local"`` via ``check_route_proto_type``
         (a direct route; the original note says proto value 3).

    Usually the device is a router, but it may be a switch or firewall.

    :param ip: address/netmask string to locate.
    :return: ``(device, interface)`` for the first qualifying route. ``interface`` is the device
        interface whose ``index`` equals the route's ``interface`` field (the last such match), or
        ``""`` if none matches. Returns ``(None, None)`` when no device qualifies.
    """
    for net_device in net_device_list:
        for route_element in net_device.route_info:
            network = route_element["destination"] + "/" + route_element["net_mask"]
            # Logical ``and`` short-circuits: the proto type is only checked for routes that already
            # cover ``ip``. A bitwise ``&`` would also fail on non-bool results.
            if not (check_aggregate_route(ip, network)
                    and check_route_proto_type(route_element["proto_type"]) == "local"):
                continue
            print("设备{}为地址{}的直连网络设备".format(net_device.device_name, ip))
            # Resolve the route's interface index to the interface object.
            contact_interface = ""
            for interface_element in net_device.interface_info:
                if interface_element.index == route_element["interface"]:
                    contact_interface = interface_element
            return net_device, contact_interface
    print("未找到与ip:{}直连的网络设备".format(ip))
    return None, None


def origin_first_routing(origin_device, destination_ip):
    """Return the interface ``origin_device`` would use to forward traffic to ``destination_ip``.

    The device's ``route_info`` is walked in order. For the first route that covers
    ``destination_ip`` (per ``check_aggregate_route``) and whose ``interface`` index matches one of
    the device's interfaces, that interface is returned. Not called anywhere in this file.

    :param origin_device: device whose route table is consulted.
    :param destination_ip: address/netmask string being routed to.
    :return: the matching interface object, or None when no covering route maps to a known interface.
    """
    for route in origin_device.route_info:
        if not check_aggregate_route(destination_ip, route["destination"] + "/" + route["net_mask"]):
            continue
        interface_index = route["interface"]
        out_interface = next(
            (iface for iface in origin_device.interface_info if iface.index == interface_index), None)
        if out_interface is not None:
            return out_interface
    return None


def simulate_routing(origin_device, destination_device, destination_ip, nexthop_interface, last_hop_interface,
                     max_hops=64):
    """Simulate hop-by-hop routing from ``origin_device`` towards ``destination_device``.

    For each device visited, a record ``{"device", "interface_in", "interface_out"}`` is appended to
    the module-level ``route_detail_list``. When the destination device is reached, a final record
    is appended whose outgoing interface is ``last_hop_interface``.

    :param origin_device: device the walk currently stands on.
    :param destination_device: device directly connected to ``destination_ip``
        (compared by ``device_name``).
    :param destination_ip: address/netmask string being routed to.
    :param nexthop_interface: interface through which traffic entered ``origin_device``.
    :param last_hop_interface: interface of ``destination_device`` that faces ``destination_ip``.
    :param max_hops: remaining hop budget. Once it is exhausted the walk fails, so a routing loop
        ends with False instead of a RecursionError.
    :return: True if ``destination_device`` is reached, otherwise False (the reason is printed).
    """
    if max_hops <= 0:
        print("设备 {} 处超过最大跳数，可能存在路由环路，路由中断".format(origin_device.device_name))
        return False

    this_route_detail = {"device": origin_device, "interface_in": nexthop_interface, "interface_out": ""}
    next_hop = None
    next_device, next_device_interface = None, None
    # NOTE(review): every route covering destination_ip is examined, and the LAST match decides the
    # next hop. That equals longest-prefix match only if route_info is ordered that way; confirm.
    for route in origin_device.route_info:
        if not check_aggregate_route(destination_ip, route["destination"] + "/" + route["net_mask"]):
            continue
        next_hop = route["nexthop"]
        if next_hop == "":
            print("设备 {} 中缺失去往{}的路由，路由中断".format(origin_device.device_name, destination_ip))
            return False
        next_device, next_device_interface = find_device_with_interface_ip(next_hop)
        interface_index = route["interface"]
        for interface_element in origin_device.interface_info:
            if interface_element.index == interface_index:
                this_route_detail["interface_out"] = interface_element
    route_detail_list.append(this_route_detail)

    if next_hop is None:
        # No route covers destination_ip at all; next_device would otherwise be unusable.
        print("设备 {} 中缺失去往{}的路由，路由中断".format(origin_device.device_name, destination_ip))
        return False
    if next_device is None:
        # The next hop is not an interface address of any known device.
        print("下一跳地址{}不属于任何已知网络设备，路由中断".format(next_hop))
        return False

    if next_device.device_name == destination_device.device_name:
        # Append the destination device itself as the last hop.
        route_detail_list.append({"device": destination_device, "interface_in": next_device_interface,
                                  "interface_out": last_hop_interface})
        return True
    return simulate_routing(next_device, destination_device, destination_ip, next_device_interface,
                            last_hop_interface, max_hops - 1)


def find_device_with_interface_ip(interface_ip):
    """Find the device that has an interface configured with ``interface_ip``.

    Devices in the module-level ``net_device_list`` and their ``interface_info`` are scanned in order.
    The first interface whose ``ip`` equals ``interface_ip`` wins, and the match is announced on stdout.

    :param interface_ip: address to look for (compared with ``==`` against each interface's ``ip``).
    :return: ``(device, interface)`` for the first match, otherwise ``(None, None)``.
    """
    candidates = (
        (device, iface)
        for device in net_device_list
        for iface in device.interface_info
        if interface_ip == iface.ip
    )
    hit = next(candidates, None)
    if hit is None:
        return None, None
    owner, owner_interface = hit
    print("下一跳地址--{}--位于设备--{}--的接口--{}--上".format(interface_ip, owner.device_name, owner_interface.name))
    return owner, owner_interface


def print_secure_policy(lines, isUpdate, editKye):
    """Print generated policy command lines as id / name / value rows for manual confirmation.

    Each line is split on whitespace:
      * 2 tokens: one row, name = token 0, value = token 1;
      * 3 tokens: one row, name = tokens 0+1 concatenated, value = token 2;
      * 4+ tokens: two rows with ids ``N-1`` (tokens 0/1) and ``N-2`` (tokens 2/3);
      * 0 or 1 token: nothing is printed.

    :param lines: command lines to display.
    :param isUpdate: when true, every row is prefixed with whether it may be modified.
    :param editKye: substring marking a line as modifiable (only consulted when ``isUpdate``).
    """
    for row, line in enumerate(lines):
        tokens = line.split()
        prefix = ""
        if isUpdate:
            prefix = "允许修改 " if editKye in line else "禁止修改 "
        label = prefix + "id：" + str(row)
        count = len(tokens)
        if count == 2:
            print(label + "     名称：" + tokens[0] + spacing(tokens[0]) + "值：" + tokens[1])
        elif count == 3:
            joined_name = tokens[0] + tokens[1]
            print(label + "     名称：" + joined_name + spacing(joined_name) + "值：" + tokens[2])
        elif count > 3:
            print(label + "-1   名称：" + tokens[0] + spacing(tokens[0]) + "值：" + tokens[1])
            print(label + "-2   名称：" + tokens[2] + spacing(tokens[2]) + "值：" + tokens[3])


def spacing(name, width=25):
    """Return the spaces that pad ``name`` out to ``width`` characters.

    Used to align the value column when printing policy rows. Returns an empty string when ``name``
    is already ``width`` characters or longer. Width counts characters (``len``), not display
    columns, so full-width (CJK) names will not line up exactly.

    :param name: text to be padded.
    :param width: target column width; the default of 25 keeps the original layout.
    :return: a string of spaces of length ``max(0, width - len(name))``.
    """
    return " " * max(0, width - len(name))


def _apply_policy_edit(lines, up_command, isUpdate, editKye):
    """Apply a single ``id#value`` edit to ``lines`` in place.

    ``id`` is either ``N``, which replaces the last token of line N, or ``N-1`` / ``N-2``, which
    replaces token 1 / token 3 of line N. These ids match those shown by print_secure_policy for
    long lines. Whitespace around the id, the sub-index and the value is ignored, so the documented
    ``id # 值`` form works. An unknown sub-index is ignored. In update mode, only lines containing
    ``editKye`` may change.

    :raises ValueError: the line number is not an integer.
    :raises IndexError: the line number or the token position does not exist.
    """
    target = up_command.split("#")[0].strip()
    new_value = up_command.split("#")[-1].strip()
    if "-" in target:
        id_parts = target.split("-")
        row = int(id_parts[0])
        values = lines[row].split()
        sub_index = id_parts[1].strip()
        if sub_index == "1":
            values[1] = new_value
            denied_message = "修改失败，此项禁止修改."
        elif sub_index == "2":
            values[3] = new_value
            denied_message = "修改失败，此项禁止修改"
        else:
            return
    else:
        row = int(target)
        values = lines[row].split()
        values[-1] = new_value
        denied_message = "修改失败，此项禁止修改"
    if isUpdate and editKye not in lines[row]:
        print(denied_message)
    else:
        lines[row] = " ".join(values)


def confirm_secure_policy_individually(lines, isUpdate, editKye):
    """Let the operator review and edit one policy's command lines until they answer Y/y.

    The lines are shown via print_secure_policy. Any other answer is parsed as comma-separated
    ``id#value`` edits, which are applied in place in order. A malformed or out-of-range edit prints
    a format hint and the prompt repeats. Edits applied before the bad entry are kept.

    :param lines: command lines to confirm; this list is modified in place.
    :param isUpdate: True when editing an existing policy, so only lines containing ``editKye``
        may change.
    :param editKye: substring identifying modifiable lines (used only when ``isUpdate``).
    :return: ``lines`` with the operator's edits applied (returned unchanged when empty).
    """
    if lines:
        user_input = ''
        while user_input not in ("Y", "y"):
            print("确认是否执行以下命令：Y 确定，输入如下格式可直接修改命令：id # 值, id # 值，如： 2#untrust,3-1#192.168.5.1,3-2#255.255.252.0 ")
            print_secure_policy(lines, isUpdate, editKye)
            user_input = input("\033[1;31m请输入：\033[0m")
            if user_input in ("Y", "y"):
                break
            try:
                for up_command in user_input.split(","):
                    _apply_policy_edit(lines, up_command, isUpdate, editKye)
            except (ValueError, IndexError):
                # IndexError covers a non-existent line id or token position, which the message
                # promises to report instead of crashing.
                print("输入格式错误或对应ID不存在，请按以下格式：id # 值, id # 值...")
    return lines


def confirm_secure_policy(generated_secure_policy):
    """Walk the operator through every generated policy (new or modified).

    Each policy dict gets a ``'newCommands'`` key holding the operator-confirmed copy of its
    ``"commands"`` list. A policy carrying an ``"edit_key"`` is treated as an update, so only lines
    containing that key are editable.

    :param generated_secure_policy: list of policy dicts with ``"commands"`` and optionally
        ``"edit_key"``; mutated in place.
    :return: ``(True, generated_secure_policy)``.
    """
    total = len(generated_secure_policy)
    for confirmed, policy in enumerate(generated_secure_policy):
        print("已确认【{}】条安全策略".format(confirmed) + "，剩余待确认【{}】条".format(total - confirmed))
        is_update = "edit_key" in policy
        edit_key = policy.get("edit_key") if is_update else None
        working_copy = policy.get("commands").copy()
        policy['newCommands'] = confirm_secure_policy_individually(working_copy, is_update, edit_key)
    return True, generated_secure_policy


def confirm_execute_result(execute_response_):
    """Show the device's response to the executed policy commands so the operator can verify them.

    :param execute_response_: raw response returned by the device's command execution.
    :return: None.
    """
    print(f"执行返回为：\n{execute_response_}\n请观察命令是否执行成功")
    return None


def confirm_network_traffic():
    """Ask the operator to check traffic manually: the USG6000V cannot be verified remotely.

    Prints the instruction and the firewall's web UI address. Returns None.
    """
    notice = ("此处应确认源目流量能否顺利通过防火墙，目前防火墙型号USG6000V不支持远程确认，"
              "请从防火墙web页面进行操作\n https://192.168.12.12:8443/")
    print(notice)


def create_all_net_device():
    """Build the device inventory for the lab topology and publish it in the module globals.

    ``route_detail_list`` is reset to an empty list. ``net_device_list`` is reset, then replaced by
    the five AR2220 routers (AR5, AR8, AR9, AR10, AR11) followed by firewall FW11. If any
    constructor raises, ``net_device_list`` stays empty. The caller's comment describes this step as
    connecting to eNSP for live data, so the constructors presumably contact the devices (confirm).

    NOTE(review): SNMP/SSH credentials are hard-coded in plain text; consider loading them from
    configuration or environment variables.

    :return: True.
    """
    global route_detail_list
    global net_device_list
    route_detail_list = []
    net_device_list = []

    # (device_name, manage_ip) for each router; all routers share the same model and credentials.
    router_inventory = (
        ("AR5", "192.168.13.3"),
        ("AR8", "10.0.10.2"),
        ("AR9", "192.168.10.2"),
        ("AR10", "172.16.10.2"),
        ("AR11", "10.0.18.1"),
    )
    devices = []
    for router_name, router_ip in router_inventory:
        devices.append(Router(
            device_type="Huawei_AR2220",
            device_name=router_name,
            manage_ip=router_ip,
            auth_type="snmp",
            snmp_port=161,
            snmp_password="public",
            ssh_port=22,
            ssh_username="admin",
            ssh_password="123",
        ))
    devices.append(Firewall(
        device_type="Huawei_USG6000V",
        device_name="FW11",
        manage_ip="192.168.12.12",
        auth_type="snmp",
        snmp_port=161,
        snmp_password="Tontron@1169!",
        ssh_port=22,
        ssh_username="sshadmin",
        ssh_password="Tontron@1169",
    ))
    # Publish only once every device was built.
    net_device_list = devices
    return True


def save_all_net_device():
    """Cache ``net_device_list`` in ``net_device_saved.jsonl``, one jsonpickle object per line.

    Every device is encoded before any file is touched. The data is then written to a temporary
    file that atomically replaces the cache. Either the previous cache survives intact or the
    complete new one is in place; a partial failure never truncates it. This matters because
    init_net_device falls back to this file when live collection fails.

    :raises OSError: writing or replacing the file failed; the previous cache is left untouched.
    """
    import os

    target_path = 'net_device_saved.jsonl'
    temp_path = target_path + '.tmp'
    # Encode first: a serialization error must not leave an emptied cache file behind.
    encoded = [jsonpickle.encode(item) for item in net_device_list]
    try:
        with open(temp_path, 'w') as f:
            f.writelines(json_str + '\n' for json_str in encoded)
        os.replace(temp_path, target_path)
    except OSError:
        if os.path.exists(temp_path):
            os.remove(temp_path)
        raise


def read_all_net_device():
    """Rebuild the device inventory from the local cache ``net_device_saved.jsonl``.

    ``route_detail_list`` and ``net_device_list`` are reset. Then each line is decoded with
    jsonpickle, the device's details are printed, and it is appended to ``net_device_list``,
    pausing 2 seconds after each device.

    NOTE(review): ``jsonpickle.decode`` can instantiate arbitrary classes; only load a cache file
    you trust.
    """
    global route_detail_list
    global net_device_list
    route_detail_list, net_device_list = [], []
    with open('net_device_saved.jsonl', 'r') as cache_file:
        for serialized in cache_file:
            restored = jsonpickle.decode(serialized)
            restored.print_device_details()
            net_device_list.append(restored)
            time.sleep(2)


def init_net_device():
    """Initialise the device inventory, preferring live data and falling back to the local cache.

    After the operator presses Enter, create_all_net_device() builds the live inventory (described
    as connecting to eNSP for real-time device info). On success the inventory is cached with
    save_all_net_device(). A save failure is only reported, so the freshly collected live data is
    kept instead of being replaced by an older cache. If live collection fails, the inventory is
    reloaded with read_all_net_device(), whose errors propagate.

    :return: True.
    """
    input("按任意键开始")
    print("--------初始化网络设备中---------")

    try:
        create_all_net_device()
    except Exception as e:  # any failure while building the live inventory triggers the cache fallback
        print(f"连接eNSP获取实时设备信息失败，开始从本地读取设备信息: {str(e)}")
        read_all_net_device()
        return True

    # Persisting is best-effort and must not discard the live inventory.
    try:
        save_all_net_device()
    except Exception as e:
        print(f"设备信息保存到本地失败，继续使用实时设备信息: {str(e)}")
    return True


def _describe_interface(interface):
    """Render an interface as ``<short name>:<ip>`` for route printouts.

    Anything that is not an ``Interface`` (e.g. the ``""`` placeholder used when no interface was
    resolved) is shown as ``"None"``.
    """
    if not isinstance(interface, Interface):
        return "None"
    short_name = interface.name.replace("Interface", "").split(",")[-1]
    return short_name + ":" + interface.ip


def _format_route_path(route_list, start_ip, end_ip):
    """Render hop records as ``[start] ---->(in)[device](out)---->...[end] ``."""
    hops = "".join(
        "(" + _describe_interface(route["interface_in"]) + ")"
        + "[" + route["device"].device_name + "]"
        + "(" + _describe_interface(route["interface_out"]) + ")"
        + "---->"
        for route in route_list
    )
    return "[" + start_ip + "] ---->" + hops + "[" + end_ip + "] "


def _read_ip_with_subnet(prompt, error_message):
    """Prompt until a valid ``a.b.c.d/prefix`` is entered; return it as ``a.b.c.d/netmask``."""
    while True:
        value = input(prompt)
        if validate_ip_with_subnet(value):
            return value.split("/")[0] + "/" + cidr_to_netmask(value.split("/")[1])
        print(error_message)


def _execute_policies(policies, label):
    """Push each confirmed policy to its firewall, then show the response and the traffic-check hint."""
    for index, policy in enumerate(policies):
        print("开始执行第{}条{}策略".format(index + 1, label))
        # Send the operator-confirmed (possibly edited) commands, not the raw generated ones.
        commands = policy.get("newCommands", policy["commands"])
        execute_response = policy["device"].execute_secure_policy_command(commands, terminator=r"return")
        # Let the operator check the execution result.
        confirm_execute_result(execute_response)
        # Remind the operator to verify that traffic now passes.
        confirm_network_traffic()


if __name__ == '__main__':
    while True:
        init_net_device()
        wait_for_signal = input("\033[1;31m信息收集完毕，系统待命中，按任意键开始使用,输入N退出\033[0m")
        if wait_for_signal == "N" or wait_for_signal == "n":
            break
        origin_ip = _read_ip_with_subnet("\033[1;31m输入源地址，请输入，格式示例：192.168.10.1/24\033[0m    \n",
                                         "源地址格式有误，示例：192.168.10.1/24")
        destination_ip = _read_ip_with_subnet("\033[1;31m输入目地址，请重新输入，格式示例：192.168.10.1/24\033[0m    \n",
                                              "目地址格式有误，示例：192.168.10.1/24")
        print("源：", origin_ip)
        print("目：", destination_ip)
        print("开始分析本地路由情况")

        # Reverse path (destination -> origin) must exist first.
        if not check_route(destination_ip, origin_ip):
            print("反向路由不存在，上报管理员，流程结束")
            route_detail_list = []
            net_device_list = []
            continue
        print("\n反向路由存在，路由信息为：\n{}\n".format(
            _format_route_path(route_detail_list, destination_ip, origin_ip)))
        route_detail_list = []

        # Forward path (origin -> destination).
        if not check_route(origin_ip, destination_ip):
            print("正向路由不存在，上报管理员，流程结束")
            route_detail_list = []
            net_device_list = []
            continue
        print("\n正向路由存在，路由信息为：\n{}\n".format(
            _format_route_path(route_detail_list, origin_ip, destination_ip)))

        # Only firewall hops on the forward path carry security policies.
        secure_route_list = [route for route in route_detail_list if isinstance(route["device"], Firewall)]
        user_input = input("\033[1;31m是否开始执行安全策略分析，Y 需要 N 不需要 请输入：\033[0m")
        if user_input == "N" or user_input == "n":
            print("无需分析，当前任务已完成")
            continue
        firewall_need_new_policy, firewall_need_edit_policy = check_secure_policy(secure_route_list, origin_ip,
                                                                                  destination_ip)
        if (len(firewall_need_new_policy) + len(firewall_need_edit_policy)) == 0:
            print("网络策略已存在，无需变更，流程结束")
            route_detail_list = []
            net_device_list = []
            continue

        # Generate commands for policies that must be created...
        created_secure_policy = []
        created_secure_policy_print = []
        for need_new in firewall_need_new_policy:
            policy_device = need_new["device"]
            commands = policy_device.create_secure_policy(need_new, origin_ip, destination_ip)
            created_secure_policy.append({"device": policy_device, "commands": commands})
            created_secure_policy_print.append({"device": policy_device.device_name, "commands": commands})
        # ...and for existing policies that must be modified.
        changed_secure_policy = []
        changed_secure_policy_print = []
        for need_edit in firewall_need_edit_policy:
            policy_device = need_edit["secure_route_detail"]["device"]
            commands = policy_device.edit_secure_policy(need_edit["old_policy"], need_edit["edit_key"],
                                                        origin_ip, destination_ip)
            changed_secure_policy.append(
                {"device": policy_device, "commands": commands, "edit_key": need_edit["edit_key"]})
            changed_secure_policy_print.append(
                {"device": policy_device.device_name, "commands": commands, "edit_key": need_edit["edit_key"]})

        # Operator confirmation; confirm_secure_policy stores the confirmed lines under 'newCommands'.
        if len(created_secure_policy) > 0:
            print("生成的新策略为{}".format(str(created_secure_policy_print)))
            confirm_secure_policy_flag_created, new_secure_policy = confirm_secure_policy(created_secure_policy)
        else:
            confirm_secure_policy_flag_created, new_secure_policy = True, []
        if len(changed_secure_policy) > 0:
            print("生成的修改策略为{}".format(str(changed_secure_policy_print)))
            confirm_secure_policy_flag_changed, edit_secure_policy = confirm_secure_policy(changed_secure_policy)
        else:
            confirm_secure_policy_flag_changed, edit_secure_policy = True, []

        if not (confirm_secure_policy_flag_created and confirm_secure_policy_flag_changed):
            print("网络策略生成失败，请转人工处理")
            route_detail_list = []
            net_device_list = []
            continue
        # Execute additions first, then modifications.
        _execute_policies(new_secure_policy, "新增")
        _execute_policies(edit_secure_policy, "修改")
